from pathlib import Path

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import random
import re
from IPython import embed

# Folder of each author cohort. Every folder is scanned for ``*.xls`` files,
# which the reader functions below load with ``pandas.read_excel``.
abel_folder = "/Users/zhoutang/Dropbox/论文/Zhou Tang/dataset/Abel"
umass_folder = "/Users/zhoutang/Dropbox/论文/Zhou Tang/dataset/UMASS"
fields_folder = "/Users/zhoutang/Dropbox/论文/Zhou Tang/dataset/Fields"
random_folder = "/Users/zhoutang/Dropbox/论文/Zhou Tang/dataset/random"
hirsch_folder = "/Users/zhoutang/Dropbox/论文/Zhou Tang/dataset/hirsch"

# ``Path`` wrappers around the folders above.
abel_path = Path(abel_folder)
umass_path = Path(umass_folder)
fields_path = Path(fields_folder)
random_path = Path(random_folder)
hirsch_path = Path(hirsch_folder)

# NOTE(review): ``Path.glob`` does not promise any particular ordering, while
# later parts of this script pick rows by position (e.g. ``[0:5]`` / ``[6:]``)
# -- confirm the file order is stable on the machine running this.
abel_files = [*abel_path.glob('*.xls')]
umass_files = [*umass_path.glob('*.xls')]
fields_files = [*fields_path.glob('*.xls')]
random_files = [*random_path.glob('*.xls')]
hirsch_files = [*hirsch_path.glob('*.xls')]

def read_citation(files_list):
    """Build per-author cumulative citation curves and h-indices from Excel files.

    Every file is loaded with ``pandas.read_excel``. Inside each sheet:

    * the row whose first cell equals ``'Title'`` is used as the header of the
      per-paper table that follows it;
    * the h-index is read from the second column, two rows above that header;
    * the author name is the text after the last ``"for:"`` in the label of
      the sheet's first column.

    Parameters
    ----------
    files_list : sequence
        Paths (or anything ``pandas.read_excel`` accepts) of the files to read.

    Returns
    -------
    citation : pandas.DataFrame
        One row per author, columns labelled 1900..2020. Each row is the sum
        over papers of the running totals taken across columns 21..141 of the
        per-paper table (presumably the yearly citation columns -- TODO
        confirm against the export layout).
    h_index : pandas.DataFrame
        One row per author with a single ``"H index"`` column.
    """
    years = pd.period_range("1900", "2020", freq="Y").year
    h_values = np.zeros(len(files_list))
    papers_by_author = {}

    for position, path in enumerate(files_list):
        sheet = pd.read_excel(path)

        # Row holding the column names of the per-paper table.
        first_column = sheet.iloc[:, 0]
        header_row = sheet.loc[first_column == 'Title'].index.values[0]
        h_values[position] = sheet.iloc[header_row - 2, 1]

        # The author name follows "for:" in the first column label.
        banner = sheet.columns.values[0]
        name_start = banner.rfind("for:") + 4
        author = banner[name_start:].strip()

        # Keep only the paper rows and give them the header found above.
        header = sheet.iloc[header_row, :]
        papers = sheet.iloc[header_row + 1:, :]
        papers.columns = header
        papers_by_author[author] = papers

    authors = list(papers_by_author.keys())
    totals = np.zeros((len(papers_by_author), len(years)))
    for row, author in enumerate(authors):
        yearly = papers_by_author[author].iloc[:, 21:142]
        # Running total per paper, then summed over all papers.
        totals[row, :] = yearly.cumsum(axis=1).sum()

    citation = pd.DataFrame(totals, index=authors, columns=years)
    h_index = pd.DataFrame(h_values, index=authors, columns=["H index"])
    return citation, h_index


# Read every cohort: each call yields (cumulative citation table, h-index table).
(
    (abel_citation, abel_h_index),
    (umass_citation, umass_h_index),
    (fields_citation, fields_h_index),
    (random_citation, random_h_index),
    (hirsch_citation, hirsch_h_index),
) = map(
    read_citation,
    (abel_files, umass_files, fields_files, random_files, hirsch_files),
)


def plot_citation(citation):
    """Draw one citation curve per row of ``citation`` and print the table.

    Only the columns from position 60 up to, but not including, the last one
    are plotted.

    Parameters
    ----------
    citation : pandas.DataFrame
        Table with one row per author and one column per year.
    """
    sns.set()
    # Transpose so that years run along the rows and each author is a column.
    curves = citation.iloc[:, 60:-1].transpose()
    sns.lineplot(data=curves, dashes=False)
    plt.show()
    print(citation)

# plot_citation(abel_citation)
# plot_citation(umass_citation)
# plot_citation(fields_citation)
# plot_citation(random_citation)
# plot_citation(hirsch_citation)


def get_acceleration(citation):
    """Return the second difference of ``citation`` taken along its columns.

    Parameters
    ----------
    citation : pandas.DataFrame
        Table with one row per author and one column per year.

    Returns
    -------
    pandas.DataFrame
        Same shape and labels as ``citation``; the first two columns are NaN
        because two successive differences are taken.
    """
    first_difference = citation.diff(axis=1)
    return first_difference.diff(axis=1)


# Second-difference acceleration of every cohort's citation table.
(
    abel_acceleration,
    umass_acceleration,
    fields_acceleration,
    random_acceleration,
    hirsch_acceleration,
) = map(
    get_acceleration,
    (abel_citation, umass_citation, fields_citation, random_citation, hirsch_citation),
)


def get_acceleration_sg(citation):
    """Estimate acceleration with a five-column weighted stencil.

    For every run of five consecutive columns ``c0..c4`` of a row the value
    ``(2*c0 - c1 - 2*c2 - c3 + 2*c4) / 7`` is computed (the weights of a
    five-point quadratic Savitzky-Golay second-derivative filter). The result
    for the window starting at column position ``k`` is stored under the label
    of column ``k + 5``, i.e. the column right after the window. The window
    made of the last five columns is not evaluated.

    The output labels are taken from ``citation.columns`` rather than from a
    fixed 1900-2020 axis, so tables of any width are accepted; for the
    1900-2020 tables built by ``read_citation`` the labels are 1905..2020 as
    before.

    Parameters
    ----------
    citation : pandas.DataFrame
        Table with one row per author and one column per year.

    Returns
    -------
    pandas.DataFrame
        Same index as ``citation``; columns are ``citation.columns[5:]``
        (no columns at all when ``citation`` has five columns or fewer).
    """
    labels = citation.columns[5:]
    values = np.asarray(citation)
    acceleration = (2*values[:, 0:-5] - values[:, 1:-4] - 2*values[:, 2:-3]
                    - values[:, 3:-2] + 2*values[:, 4:-1])/7
    return pd.DataFrame(acceleration, index=citation.index, columns=labels)


# Five-point stencil acceleration of every cohort's citation table.
(
    abel_acceleration_sg,
    umass_acceleration_sg,
    fields_acceleration_sg,
    random_acceleration_sg,
    hirsch_acceleration_sg,
) = map(
    get_acceleration_sg,
    (abel_citation, umass_citation, fields_citation, random_citation, hirsch_citation),
)


def plot_acceleration(acceleration, window=5, start=50, location='best'):
    """Plot rolling-mean acceleration curves, one line per author.

    Parameters
    ----------
    acceleration : pandas.DataFrame
        Table with one row per author and one column per year.
    window : int, optional
        Number of consecutive years averaged by the rolling mean.
    start : int, optional
        Position of the first year to draw; the last two years are never
        drawn.
    location : str, optional
        Legend location passed to ``Axes.legend``.
    """
    # Smooth along the year axis. The ``axis=1`` argument of
    # ``DataFrame.rolling`` is deprecated, so roll down the rows of the
    # transposed table instead; that is also the orientation seaborn needs
    # (years on the rows, one column per author).
    smoothed = acceleration.transpose().rolling(window=window).mean()

    sns.set()
    ax = sns.lineplot(data=smoothed.iloc[start:-2], dashes=False)
    ax.set_ylabel("W1")
    ax.legend(loc=location)
    plt.show()


# plot_acceleration(abel_acceleration)
# plot_acceleration(random_acceleration)
# plot_acceleration(fields_acceleration)
# plot_acceleration(hirsch_acceleration)
# plot_acceleration(umass_acceleration)

# plot_acceleration(fields_acceleration_sg.loc["Tao, Terence"], 1, start=50)

# plot_acceleration(abel_acceleration_sg)
# plot_acceleration(umass_acceleration_sg)
# plot_acceleration(fields_acceleration_sg)
# plot_acceleration(random_acceleration_sg)
# plot_acceleration(hirsch_acceleration_sg)


def save_acceleration(abel_acc, abel_acc_sg, fields_acc, fields_acc_sg):
    """Draw a 2x2 grid of acceleration curves and save it as ``W1_W5.pdf``.

    The top row shows ``abel_acc`` and ``fields_acc`` (y-label ``W1`` on the
    left panel); the bottom row shows ``abel_acc_sg`` and ``fields_acc_sg``
    (y-label ``W5`` on the left panel). Each panel drops the last two columns
    of its table and starts at a fixed column position.

    NOTE(review): the start positions (50/100 vs 48/98) index each table's own
    columns. The ``*_sg`` tables built in this script start five columns
    later than the plain ones, so the two rows may not cover the same years --
    confirm this is intended.

    Parameters
    ----------
    abel_acc, abel_acc_sg, fields_acc, fields_acc_sg : pandas.DataFrame
        Tables with one row per author and one column per year.
    """
    sns.set_style("white")
    _, axes = plt.subplots(2, 2)

    # (table, first column position, target axes, y-label or None)
    panels = (
        (abel_acc, 50, axes[0, 0], "W1"),
        (fields_acc, 100, axes[0, 1], None),
        (abel_acc_sg, 48, axes[1, 0], "W5"),
        (fields_acc_sg, 98, axes[1, 1], None),
    )
    for table, first, axis, label in panels:
        curves = table.iloc[:, first:-2].transpose()
        sns.lineplot(data=curves, dashes=False, ax=axis)
        if label is not None:
            axis.set_ylabel(label)
        axis.legend(loc="best", fontsize=8)

    plt.savefig("W1_W5.pdf", format="pdf")
    plt.show()

# Authors shown in the W1/W5 figure, three per cohort.
abel_featured = ["Milnor, J", "Deligne, P", "Langlands, RP"]
fields_featured = ["Birkar, Caucher", "Hairer, Martin", "Bhargava, Manjul"]

abel_acc = abel_acceleration.loc[abel_featured]
abel_acc_sg = abel_acceleration_sg.loc[abel_featured]
fields_acc = fields_acceleration.loc[fields_featured]
fields_acc_sg = fields_acceleration_sg.loc[fields_featured]

save_acceleration(abel_acc, abel_acc_sg, fields_acc, fields_acc_sg)




# ##################################################################
# Compute the mean and standard deviation of each author's acceleration.
def get_mean_std(citation, acceleration, start=2, end=8):
    """Summarise each author's acceleration over a window after a start year.

    An author's start year is the first column label of ``citation`` whose
    value is at least 10. The window covers the ``acceleration`` columns
    labelled from ``start_year + start`` to ``start_year + end``, both
    inclusive. When ``start_year + end`` is 2020 or later, the unclipped
    window is printed and the upper bound is replaced by 2019.

    Rows of ``citation`` and ``acceleration`` are paired by position.

    Parameters
    ----------
    citation : pandas.DataFrame
        Table with one row per author and year labels as columns.
    acceleration : pandas.DataFrame
        Acceleration table with year labels as columns.
    start, end : int, optional
        Offsets, in years from the start year, bounding the window.

    Returns
    -------
    pandas.DataFrame
        Indexed like ``citation`` with columns ``"mean"`` and ``"std"``
        (``numpy.mean`` / ``numpy.std`` of the selected values).

    Raises
    ------
    IndexError
        If some row of ``citation`` never reaches 10.
    """
    n_authors = citation.shape[0]
    means = [0] * n_authors
    stds = [0] * n_authors

    # First year in which each author's count reaches 10.
    start_years = [counts[counts >= 10].index[0] for _, counts in citation.iterrows()]

    years = acceleration.columns
    for i, year in enumerate(start_years):
        lower = year + start
        upper = year + end
        selected = np.logical_and(lower <= years, years <= upper)
        if upper >= 2020:
            # Show the unclipped window, then stop at 2019.
            print(acceleration.iloc[i][selected])
            selected = np.logical_and(lower <= years, years <= 2019)
        window = acceleration.iloc[i][selected]
        means[i] = np.mean(window)
        stds[i] = np.std(window)

    mean_frame = pd.DataFrame(means, index=citation.index, columns=["mean"])
    std_frame = pd.DataFrame(stds, index=citation.index, columns=["std"])
    return pd.concat([mean_frame, std_frame], axis=1)


# Window shared by the summaries below: start year + 2 .. start year + 10.
_window = {"start": 2, "end": 10}

# Summaries of the plain second-difference acceleration.
abel_mean_std = get_mean_std(abel_citation, abel_acceleration, **_window)
umass_mean_std = get_mean_std(umass_citation, umass_acceleration, **_window)
fields_mean_std = get_mean_std(fields_citation, fields_acceleration, **_window)
random_mean_std = get_mean_std(random_citation, random_acceleration, **_window)
hirsch_mean_std = get_mean_std(hirsch_citation, hirsch_acceleration, **_window)

# Summaries of the five-point stencil acceleration.
abel_mean_std_sg = get_mean_std(abel_citation, abel_acceleration_sg, **_window)
umass_mean_std_sg = get_mean_std(umass_citation, umass_acceleration_sg, **_window)
# NOTE(review): this is the only call that starts at +5 instead of +2 --
# confirm the different window is intentional.
fields_mean_std_sg = get_mean_std(fields_citation, fields_acceleration_sg, start=5, end=10)
random_mean_std_sg = get_mean_std(random_citation, random_acceleration_sg, **_window)
hirsch_mean_std_sg = get_mean_std(hirsch_citation, hirsch_acceleration_sg, **_window)

# Side-by-side table for the Abel cohort, rows in the order listed here.
abel_row_order = [
    "Serre, JP", "Atiyah, MF", "Singer, IM", "Lax, PD", "Carleson, L.",
    "Varadhan, SRS", "Tits, Jozef", "TATE, JT", "Milnor, J", "Szemeredi, E",
    "Deligne, P", "Sinai, YG", "Nirenberg, L", "Wiles, A.", "Meyer, Yves",
    "Langlands, RP", "Uhlenbeck, K",
]
table = pd.concat([abel_mean_std, abel_mean_std_sg], axis=1)
table.columns = ["Mean(W)", "SD(W)", "Mean(W5)", "SD(W5)"]
table = table.loc[abel_row_order]
# pd.concat([hirsch_mean_std, hirsch_mean_std_sg], axis=1)


# ##################################################################
# ready to plot

import seaborn as sns

# Tag each cohort's summary (in place) with the label used on the x-axis.
for cohort_frame, cohort_label in (
    (abel_mean_std, "Abel"),
    (umass_mean_std, "UMass"),
    (fields_mean_std, "Fields"),
    (random_mean_std, "Random"),
):
    cohort_frame["type"] = cohort_label

# Stack the four cohorts and leave out the "Tao, Terence" row.
dataset = pd.concat(
    [abel_mean_std, umass_mean_std, fields_mean_std, random_mean_std],
    axis=0,
).drop(["Tao, Terence"])

# One point per author: mean acceleration, grouped by cohort.
ax = sns.stripplot(x="type", y="mean", data=dataset)
ax.set_xlabel("Dataset")
ax.set_ylabel("W")
plt.savefig("scatter.pdf", format="pdf")
plt.show()



# ##################################################################

def get_x_y(citation, acceleration, fitted_year=3, predicted_year=6):
    """Return mean acceleration over two consecutive windows per author.

    An author's start year is the first column label of ``citation`` whose
    value is at least 10. With ``base = start_year + 2`` the two windows are

    * ``w1``: years ``base`` .. ``base + fitted_year - 1``
    * ``w2``: years ``base + fitted_year`` .. ``base + predicted_year - 1``

    so with the defaults each window is three years long and they do not
    overlap. (``Series.loc[a:b]`` includes both end labels, which is why one
    is subtracted from each upper bound; slicing up to ``base + fitted_year``
    would make the windows four years long and share a year.)

    Rows of ``citation`` and ``acceleration`` are paired by position. When the
    last year of the second window is 2020 or later, ``"too large"`` and the
    author's acceleration row are printed.

    Parameters
    ----------
    citation : pandas.DataFrame
        Table with one row per author and year labels as columns.
    acceleration : pandas.DataFrame
        Acceleration table with year labels as columns.
    fitted_year : int, optional
        Length in years of the first window.
    predicted_year : int, optional
        Offset from ``base`` at which the second window ends (exclusive).

    Returns
    -------
    tuple of numpy.ndarray
        ``(w1, w2)``, one value per row of ``citation``.

    Raises
    ------
    IndexError
        If some row of ``citation`` never reaches 10.
    """
    # First year in which each author's count reaches 10.
    start_year = [counts[counts >= 10].index[0] for _, counts in citation.iterrows()]

    w1 = [0] * citation.shape[0]
    w2 = [0] * citation.shape[0]
    for i, year in enumerate(start_year):
        row = acceleration.iloc[i]
        base = 2 + year
        split = base + fitted_year
        stop = base + predicted_year
        # .loc slices are inclusive on both ends, hence the "- 1".
        w1[i] = np.mean(row.loc[base:split - 1])
        w2[i] = np.mean(row.loc[split:stop - 1])
        if stop - 1 >= 2020:
            print("too large")
            print(row)
    return np.array(w1), np.array(w2)


# Early-window (w8) and later-window (w16) mean acceleration per cohort.
(
    (abel_w8, abel_w16),
    (hirsch_w8, hirsch_w16),
    (fields_w8, fields_w16),
    (umass_w8, umass_w16),
    (random_w8, random_w16),
) = (
    get_x_y(cohort_citation, cohort_acceleration)
    for cohort_citation, cohort_acceleration in (
        (abel_citation, abel_acceleration),
        (hirsch_citation, hirsch_acceleration),
        (fields_citation, fields_acceleration),
        (umass_citation, umass_acceleration),
        (random_citation, random_acceleration),
    )
)

# Pooled sample for the regression: Abel, Fields without the entry at
# position 5, UMass and random. The hirsch cohort is not pooled.
# NOTE(review): presumably position 5 is "Tao, Terence", who is dropped by
# name for the strip plot -- verify, since the position depends on file order.
final_w8 = np.concatenate(
    [abel_w8, fields_w8[:5], fields_w8[6:], umass_w8, random_w8]
)
final_w16 = np.concatenate(
    [abel_w16, fields_w16[:5], fields_w16[6:], umass_w16, random_w16]
)

# abel_w8, abel_w16 = get_x_y(abel_citation, abel_acceleration_sg)
# hirsch_w8, hirsch_w16 = get_x_y(hirsch_citation, hirsch_acceleration_sg)
# fields_w8, fields_w16 = get_x_y(fields_citation, fields_acceleration_sg)
# umass_w8, umass_w16 = get_x_y(umass_citation, umass_acceleration_sg)
# final_w8 = np.concatenate([abel_w8, hirsch_w8[[0,1,2,3,4,5,7,8,9,10]], fields_w8[0:7], fields_w8[8:], umass_w8])
# final_w16 = np.concatenate([abel_w16, hirsch_w16[[0,1,2,3,4,5,7,8,9,10]], fields_w16[0:7], fields_w16[8:], umass_w16])


# final_h_index = np.concatenate([abel_h_index, hirsch_h_index, fields_h_index[0:7], fields_h_index[8:], umass_h_index])

import statsmodels.api as sm

# ``sm.OLS`` takes (endog, exog), i.e. (response, predictor). The figure below
# has final_w8 on the x-axis and final_w16 on the y-axis, so the response is
# final_w16 and the predictor is final_w8. No constant column is added, so the
# fitted line passes through the origin.
model = sm.OLS(final_w16, final_w8).fit()
# The summary is only displayed when this runs in an interactive session.
model.summary()

plt.scatter(final_w8, final_w16)
plt.xlabel("mean of W for years 1-3")
plt.ylabel("mean of W for years 4-6")

# Fitted line evaluated on x = 0..39.
grid = np.arange(40)
plt.plot(grid, model.predict(grid), linestyle=":")
plt.savefig("regression.pdf", format="pdf")
plt.show()




#######################################
# Stack all five cohorts; both tables use the same cohort order so that rows
# line up by position.
final_acceleration = pd.concat(
    [abel_acceleration, umass_acceleration, fields_acceleration,
     random_acceleration, hirsch_acceleration]
)
final_citation = pd.concat(
    [abel_citation, umass_citation, fields_citation,
     random_citation, hirsch_citation]
)

# Across-author correlation between acceleration and cumulative citations,
# one value per column position 50..117.
column_positions = range(50, 118)
correlation = [
    np.corrcoef(final_acceleration.iloc[:, i], final_citation.iloc[:, i])[0, 1]
    for i in column_positions
]


#######################################
def read_year_paper(files_list):
    """Return cumulative per-year paper counts, one row per Excel file.

    Every file is loaded with ``pandas.read_excel``. The rows below the one
    whose first cell equals ``'Title'`` are treated as papers, and the last
    four characters of each paper's cell in column position 7 are parsed as
    its year. Counts are stored at offset ``year - 1900`` and then accumulated
    along the year axis.

    NOTE(review): the table has 120 columns (offsets 0..119), so a year of
    2020 raises ``IndexError`` and a year before 1900 wraps around to the end
    of the row -- confirm the data never contains such years.

    Parameters
    ----------
    files_list : sequence
        Paths (or anything ``pandas.read_excel`` accepts) of the files to read.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(files_list), 120)`` with running totals.
    """
    publish_count = np.zeros((len(files_list), 120))
    for row, path in enumerate(files_list):
        sheet = pd.read_excel(path)

        # Row holding the column names of the per-paper table.
        first_column = sheet.iloc[:, 0]
        header_row = sheet.loc[first_column == 'Title'].index.values[0]

        for entry in sheet.iloc[header_row + 1:, 7]:
            # The year is the last four characters of the cell's text.
            offset = int(str(entry)[-4:]) - 1900
            publish_count[row, offset] += 1
    return np.cumsum(publish_count, axis=1)


# Cumulative paper counts per year for every cohort.
(
    abel_publish_year,
    umass_publish_year,
    fields_publish_year,
    random_publish_year,
    hirsch_publish_year,
) = map(
    read_year_paper,
    (abel_files, umass_files, fields_files, random_files, hirsch_files),
)

# Stack the cohorts in the same order as final_acceleration.
final_publish_year = np.concatenate(
    [abel_publish_year, umass_publish_year, fields_publish_year,
     random_publish_year, hirsch_publish_year],
    axis=0,
)

# Across-author correlation between acceleration and cumulative paper count,
# one value per column position 50..117.
correlation2 = [
    np.corrcoef(final_acceleration.iloc[:, i], final_publish_year[:, i])[0, 1]
    for i in range(50, 118)
]
# Pair each correlation with its year (shown only in an interactive session).
list(zip(range(1950, 2018), correlation2))

#######################################
# h-index values paired, in order, with the ten hirsch-cohort rows kept below
# (the row at position 5 is left out).
hirsch_h_values = [79, 88, 77, 75, 91, 110, 68, 66, 70, 73]

# Correlation of the h-index with the mean of the second-difference acceleration.
hirsch = pd.concat([hirsch_mean_std.iloc[:5], hirsch_mean_std.iloc[6:]], axis=0)
h_index = pd.DataFrame(hirsch_h_values, index=hirsch.index, columns=["h-index"])
df = np.array(pd.concat([h_index, hirsch], axis=1))
np.corrcoef(df[:, 0], df[:, 1])


# Same correlation using the five-point stencil acceleration.
hirsch = pd.concat([hirsch_mean_std_sg.iloc[:5], hirsch_mean_std_sg.iloc[6:]], axis=0)
h_index = pd.DataFrame(hirsch_h_values, index=hirsch.index, columns=["h-index"])
df = np.array(pd.concat([h_index, hirsch], axis=1))
np.corrcoef(df[:, 0], df[:, 1])



